package main

import (
	"fmt"
	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
	apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
	"github.com/tal-tech/go-zero/tools/goctl/config"
	"github.com/tal-tech/go-zero/tools/goctl/util"
	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
	"github.com/tal-tech/go-zero/tools/goctl/util/format"
	"github.com/tal-tech/go-zero/tools/goctl/vars"
	"io"
	"strings"
)

// Base file name and built-in template for the generated service-context file
// (package svc), both consumed by genServiceContext.
const (
	// contextFilename is the base name of the generated file; genServiceContext
	// passes it through format.FileNamingFormat and appends ".go".
	contextFilename = "service_context"
	// contextTemplate is the built-in template for the service context. Its
	// placeholders ({{.configImport}}, {{.typesImport}}, {{.config}},
	// {{.middleware}}, {{.middlewareAssignment}}, {{.types}}) are filled from
	// the data map built in genServiceContext.
	//
	// The generated NewServiceContext opens a GORM connection for the "mysql"
	// or "postgres" driver named in c.Database.DriverName, calls log.Fatalln
	// (which exits the process) if opening fails, runs mbt.New(...).Run() with
	// the same driver name and DSN, and stores the result in the package-level
	// variable Svc. The generated AutoMigrate passes the {{.types}} list to
	// it.DB.AutoMigrate.
	//
	// The Chinese inline comments emitted into the generated gorm.Config read,
	// in order: "enable transactions", "table prefix", "use singular table
	// names", "disable SQL dry-run", "create logical foreign keys", "print SQL
	// statements"; the one on DefaultStringSize reads "default size for the
	// database varchar type".
	//
	// NOTE(review): if DriverName is neither "mysql" nor "postgres", no switch
	// case runs, so conn stays nil and err stays nil and the generated code
	// continues with a nil *gorm.DB — confirm whether a default/error case is
	// wanted.
	contextTemplate = `package svc

import (
	{{.configImport}}
	{{.typesImport}}
	"fmt"
	"github.com/tal-tech/go-zero/core/stores/cache"
	"github.com/tal-tech/go-zero/core/syncx"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"github.com/kotlin2018/mbt"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
	"log"
	"strconv"
)

var Svc *ServiceContext

type ServiceContext struct {
	Config {{.config}}
	DB  *gorm.DB
	Cache  cache.Cache
	{{.middleware}}
}

func NewServiceContext(c {{.config}}){
	var (
		conn *gorm.DB
		dsn string
		err error
	)
	gormC := &gorm.Config{
		SkipDefaultTransaction: false, //启用事务
		NamingStrategy: schema.NamingStrategy{
			TablePrefix:   "",   //表前缀
			SingularTable: true, //使用单数表名
		},
		DryRun:                                   false,                               //禁止SQL空跑
		DisableForeignKeyConstraintWhenMigrating: true,                                //创建逻辑外键
		Logger:                                   logger.Default.LogMode(logger.Info), //输出 SQL语句
	}
	switch c.Database.DriverName {
	case "mysql":
		dsn = c.Database.User +":" + c.Database.Password + "@tcp(" + c.Database.Host + ":" + strconv.Itoa(c.Database.Port) + ")/" + c.Database.DBName +"?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
		conn, err = gorm.Open(mysql.New(mysql.Config{
			DSN:               dsn,
			DefaultStringSize: 171, //数据库varchar类型的默认值
		}), gormC)
	case "postgres":
		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable TimeZone=Asia/Shanghai",c.Database.Host, c.Database.Port, c.Database.User, c.Database.Password, c.Database.DBName)
		conn, err = gorm.Open(postgres.Open(dsn), gormC)
	}
	if err !=nil {
			log.Fatalln("connect database err",err)
    }
	mbt.New(&mbt.Database{
   		Pkg: "./",
   		DriverName: c.Database.DriverName,
   		DSN: dsn,
		Logger: &mbt.Logger{
				PrintSql: true,
      		 	PrintXml: false,
      		 	Path: "./log.log",
   		},
	}).Run()
	Svc = &ServiceContext{
		Config: c, 
		DB: conn,
		Cache: cache.New(c.CacheRedis, syncx.NewSingleFlight(), cache.NewStat(""), nil) ,
		{{.middlewareAssignment}}
	}
}

func (it *ServiceContext)AutoMigrate(){
	it.DB.AutoMigrate(
{{.types}})
}
`
)

// genServiceContext generates the service-context source file under
// dir/contextDir, using contextTemplate as the built-in template (genFile may
// instead use a user template file, contextTemplateFile).
//
// The template data contains:
//   - the config import, plus the middleware and rest imports when api uses
//     any middleware;
//   - the types import, only when BuildTypes2 selected at least one model;
//   - one "<Name> rest.Middleware" field and one constructor assignment per
//     middleware reported by getMiddleware(api);
//   - the AutoMigrate argument list produced by BuildTypes2(api.Types).
//
// Errors from BuildTypes2, format.FileNamingFormat and genFile are returned.
func genServiceContext(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
	// Check this error explicitly: it used to be silently overwritten by the
	// next assignment to err.
	models, err := BuildTypes2(api.Types)
	if err != nil {
		return err
	}

	filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
	if err != nil {
		return err
	}

	// Writes to a strings.Builder never fail, so the Fprintf results are ignored.
	var middlewareFields, middlewareAssignments strings.Builder
	for _, item := range getMiddleware(api) {
		fmt.Fprintf(&middlewareFields, "%s rest.Middleware\n", item)
		// NOTE(review): the constructor name drops the "Middleware" suffix
		// (e.g. AuthMiddleware -> middleware.NewAuth().Handle); confirm it
		// matches what the middleware generator emits.
		name := strings.TrimSuffix(item, "Middleware")
		fmt.Fprintf(&middlewareAssignments, "%s: %s,\n", item,
			fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle"))
	}

	configImport := "\"" + ctlutil.JoinPackages(rootPkg, configDir) + "\""
	if middlewareFields.Len() > 0 {
		configImport += "\n\t\"" + ctlutil.JoinPackages(rootPkg, middlewareDir) + "\""
		configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceURL)
	}

	// The built-in template only references the types package through the
	// AutoMigrate arguments. Importing it when there are no models would make
	// the generated file fail to compile ("imported and not used").
	var typesImport string
	if models != "" {
		typesImport = "\"" + ctlutil.JoinPackages(rootPkg, interval+typesPacket) + "\""
	}

	return genFile(fileGenConfig{
		dir:             dir,
		subdir:          contextDir,
		filename:        filename + ".go",
		templateName:    "contextTemplate",
		category:        category,
		templateFile:    contextTemplateFile,
		builtinTemplate: contextTemplate,
		data: map[string]string{
			"configImport":         configImport,
			"typesImport":          typesImport,
			"config":               "config.Config",
			"middleware":           middlewareFields.String(),
			"middlewareAssignment": middlewareAssignments.String(),
			"types":                models,
		},
	})
}

// BuildTypes2 renders the AutoMigrate argument list for the given API types.
// Each type is handed to writeType2. The concatenated output is returned, or
// the first error wrapped with the failing type's name (in which case the
// string is empty).
func BuildTypes2(types []spec.Type) (string, error) {
	var out strings.Builder
	for i := range types {
		current := types[i]
		err := writeType2(&out, current)
		if err == nil {
			continue
		}
		return "", apiutil.WrapErr(err, "Type "+current.Name()+" generate error")
	}
	return out.String(), nil
}

// writeType2 writes one AutoMigrate argument, "&types.<Name>{},\n", to writer
// when tp is a struct that:
//   - has at least one member, and
//   - has at least one doc line containing "Model".
//
// Any other struct is skipped without error. An error is returned when tp is
// not a spec.DefineStruct, or when writing to writer fails.
func writeType2(writer io.Writer, tp spec.Type) error {
	obj, ok := tp.(spec.DefineStruct)
	if !ok {
		return fmt.Errorf("unspport struct type: %s", tp.Name())
	}
	if len(obj.Members) == 0 {
		return nil
	}
	// Check each doc line on its own. The previous
	// strings.Join(obj.Docs, "Model") used "Model" as the separator, so any
	// struct with two or more doc lines matched regardless of their content.
	for _, doc := range obj.Docs {
		if strings.Contains(doc, "Model") {
			_, err := fmt.Fprintf(writer, "&types.%s{},\n", util.Title(tp.Name()))
			return err
		}
	}
	return nil
}